# -*- coding: UTF-8 -*-
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
from functools import partial

import paddle
from paddlenlp.data import Stack, Tuple, Pad, Vocab
from paddlenlp.transformers import ErnieTokenizer

sys.path.append('../..')
from pycorrector.ernie_csc.utils import convert_example, parse_decode
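
# The --model_file/--params_file arguments below expect a static graph
# exported ahead of time. A minimal export sketch, where the dynamic-graph
# `model` and the exact input specs are assumptions, not defined here:
#
#   input_spec = [
#       paddle.static.InputSpec(shape=[None, None], dtype='int64'),  # input_ids
#       paddle.static.InputSpec(shape=[None, None], dtype='int64'),  # pinyin_ids
#   ]
#   static_model = paddle.jit.to_static(model, input_spec=input_spec)
#   paddle.jit.save(static_model, './static_graph_params')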


class Predictor:
    def __init__(self, model_file, params_file, device, max_seq_length,
                 tokenizer, pinyin_vocab):
        self.max_seq_length = max_seq_length

        config = paddle.inference.Config(model_file, params_file)
        if device == "gpu":
            # Initialize a 100 MB GPU memory pool on card 0.
            config.enable_use_gpu(100, 0)
        elif device == "cpu":
            # Run on CPU; options such as enable_mkldnn() or
            # set_cpu_math_library_num_threads() could also be set here.
            config.disable_gpu()
        # Read inputs and outputs through handles instead of feed/fetch ops.
        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)

        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
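        # Handle order follows the export order of the static graph, assumed
        # to be input_ids first and pinyin_ids second; predict() fills these
        # handles by position.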

        self.det_error_probs_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])
        self.corr_logits_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[1])
        self.tokenizer = tokenizer
        self.pinyin_vocab = pinyin_vocab

    def predict(self, sentences, batch_size=1):
        """
        Predicts the data labels.

        Args:
            sentences (obj:`List(Example)`): The processed data and each element is a example (numedtuple) object.
                A Example object contains `text`(word_ids) and `seq_len`(sequence length).
            batch_size(obj:`int`, defaults to 1): The number of batch.

        Returns:
            results(obj:`dict`): All the predictions labels.
        """
        examples = []
        texts = []
        trans_func = partial(
            convert_example,
            tokenizer=self.tokenizer,
            pinyin_vocab=self.pinyin_vocab,
            max_seq_length=self.max_seq_length,
            is_test=True)

        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id, dtype='int64'),  # input_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id, dtype='int64'),  # token_type_ids
            Pad(axis=0, pad_val=self.pinyin_vocab.token_to_idx[self.pinyin_vocab.pad_token], dtype='int64'),  # pinyin_ids
            Stack(axis=0, dtype='int64'),  # length
        ): [data for data in fn(samples)]
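        # For example, Pad(axis=0, pad_val=0) would turn [[1, 2], [3]] into
        # [[1, 2], [3, 0]]; Tuple applies one such fn to each field of the
        # batch, so batchify_fn yields four batched arrays.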

        for text in sentences:
            example = {"source": text.strip()}
            input_ids, token_type_ids, pinyin_ids, length = trans_func(example)
            examples.append((input_ids, token_type_ids, pinyin_ids, length))
            texts.append(example["source"])

        batch_examples = [
            examples[idx:idx + batch_size]
            for idx in range(0, len(examples), batch_size)
        ]
        batch_texts = [
            texts[idx:idx + batch_size]
            for idx in range(0, len(texts), batch_size)
        ]
        results = []

        for batch_example, batch_text in zip(batch_examples, batch_texts):
            token_ids, token_type_ids, pinyin_ids, length = batchify_fn(
                batch_example)
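            # The exported static graph is assumed to take only two inputs,
            # input_ids and pinyin_ids; token_type_ids is padded by
            # batchify_fn but never fed to the predictor.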
            self.input_handles[0].copy_from_cpu(token_ids)
            self.input_handles[1].copy_from_cpu(pinyin_ids)
            self.predictor.run()
            det_error_probs = self.det_error_probs_handle.copy_to_cpu()
            corr_logits = self.corr_logits_handle.copy_to_cpu()

            # The detection head's argmax yields a per-token error flag; the
            # correction head's argmax yields the corrected token id.
            det_pred = det_error_probs.argmax(axis=-1)
            char_preds = corr_logits.argmax(axis=-1)

            for i in range(len(length)):
                pred_result = parse_decode(batch_text[i], char_preds[i],
                                           det_pred[i], length[i],
                                           self.tokenizer, self.max_seq_length)

                results.append(''.join(pred_result))
        return results


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_file", type=str, required=True, default='./static_graph_params.pdmodel',
                        help="The path to model info in static graph.")
    parser.add_argument("--params_file", type=str, required=True, default='./static_graph_params.pdiparams',
                        help="The path to parameters in static graph.")
    parser.add_argument("--batch_size", type=int, default=4, help="The number of sequences contained in a mini-batch.")
    parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest seqence.")
    parser.add_argument("--device", default="gpu", type=str, choices=["cpu", "gpu"],
                        help="The device to select to train the model, is must be cpu/gpu.")
    parser.add_argument("--pinyin_vocab_file_path", type=str, default="pinyin_vocab.txt", help="pinyin vocab file path")

    args = parser.parse_args()
    # yapf: enable
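
    # Example invocation (the script filename here is illustrative):
    #   python infer.py --model_file ./static_graph_params.pdmodel \
    #                   --params_file ./static_graph_params.pdiparams --device cpu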

    tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
    pinyin_vocab = Vocab.load_vocabulary(
        args.pinyin_vocab_file_path, unk_token='[UNK]', pad_token='[PAD]')
    predictor = Predictor(args.model_file, args.params_file, args.device,
                          args.max_seq_len, tokenizer, pinyin_vocab)
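
    # The samples below contain deliberate spelling errors for the model to
    # fix, e.g. 逆竟 -> 逆境 ("adversity") and 拙壮 -> 茁壮 ("sturdy").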

    samples = [
        '遇到逆竟时，我们必须勇于面对，而且要愈挫愈勇，这样我们才能朝著成功之路前进。',
        '人生就是如此，经过磨练才能让自己更加拙壮，才能使自己更加乐观。',
    ]

    results = predictor.predict(samples, batch_size=args.batch_size)
    for source, target in zip(samples, results):
        print("Source:", source)
        print("Target:", target)
